import torch
import torch.nn as nn


class PlasticityModel(nn.Module):
    """Identity plasticity model.

    Performs no plastic correction: ``forward`` hands back the deformation
    gradient it receives, unchanged (the very same tensor object, no copy).
    The module registers no parameters.
    """

    def __init__(self):
        super(PlasticityModel, self).__init__()

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Return the deformation gradient without any plastic correction.

        Args:
            F (torch.Tensor): deformation gradient tensor (B, 3, 3).

        Returns:
            torch.Tensor: ``F`` itself (B, 3, 3); no copy is made.
        """
        F_corrected = F  # identity mapping: nothing is projected back
        return F_corrected


class ElasticityModel(nn.Module):

    def __init__(self,
                 youngs_modulus_log: float = 8.37,
                 poissons_ratio: float = 0.49):
        """
        Fixed-corotated elasticity with trainable material parameters.

        Args:
            youngs_modulus_log (float): natural log of Young's modulus. Stored
                as a trainable parameter and exponentiated in ``forward`` so
                that E stays positive.
            poissons_ratio (float): Poisson's ratio. Stored unconstrained and
                clamped to [0, 0.49] in ``forward``.
        """
        super().__init__()
        self.youngs_modulus_log = nn.Parameter(torch.tensor(youngs_modulus_log))
        self.poissons_ratio = nn.Parameter(torch.tensor(poissons_ratio))

    @staticmethod
    def _rotation_variant_svd(F: torch.Tensor):
        """Return U, sigma, Vh with F = U @ diag(sigma) @ Vh and det(U) = det(Vh) = +1.

        ``torch.linalg.svd`` returns non-negative singular values, sorted in
        descending order, and U/Vh that may be reflections. When det(U) < 0
        (resp. det(Vh) < 0), negate the last column of U (resp. last row of
        Vh) together with the last, i.e. smallest, singular value. This leaves
        the product unchanged, makes U and Vh proper rotations, and lets an
        inverted F (det(F) < 0) show up as a negative sigma[..., 2].
        """
        U, sigma, Vh = torch.linalg.svd(F)  # U: (...,3,3), sigma: (...,3), Vh: (...,3,3)

        # +1 / -1 per matrix; the sign test is not differentiated, so detach.
        sign_u = 1.0 - 2.0 * (torch.linalg.det(U.detach()) < 0).to(sigma.dtype)
        sign_vh = 1.0 - 2.0 * (torch.linalg.det(Vh.detach()) < 0).to(sigma.dtype)

        keep = torch.ones_like(sigma[..., :2])
        flip_u = torch.cat([keep, sign_u[..., None]], dim=-1)    # (...,3)
        flip_vh = torch.cat([keep, sign_vh[..., None]], dim=-1)  # (...,3)

        U = U * flip_u[..., None, :]     # scale columns of U
        Vh = Vh * flip_vh[..., :, None]  # scale rows of Vh
        sigma = sigma * flip_u * flip_vh
        return U, sigma, Vh

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Compute Kirchhoff stress tensor from deformation gradient tensor.

            tau = 2 * mu * (F - R) @ F^T + lambda * J * (J - 1) * I

        Here R is the rotation factor of F and J = det(F). For inverted
        elements (det(F) < 0), R is kept a proper rotation and J keeps its
        negative sign. A reflected F therefore produces a restoring stress
        instead of being treated as stress-free.

        Args:
            F (torch.Tensor): deformation gradient tensor (B, 3, 3); any
                leading batch shape (..., 3, 3) is accepted.

        Returns:
            kirchhoff_stress (torch.Tensor): Kirchhoff stress tensor, same
                shape as F.
        """
        # Physical parameters
        E = self.youngs_modulus_log.exp()  # scalar
        nu = torch.clamp(self.poissons_ratio, 0.0, 0.49)  # scalar; < 0.5 keeps la finite

        # Lamé parameters
        mu = E / (2.0 * (1.0 + nu))  # scalar
        la = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))  # scalar

        # NOTE(review): per the torch.linalg.svd docs, gradients through U/Vh
        # are only finite when singular values are distinct. F = I (rest
        # state) has repeated singular values. Confirm whether callers
        # backpropagate through F.
        U, sigma, Vh = self._rotation_variant_svd(F)

        # Keep |sigma| >= 1e-5 while preserving sign, so an inverted element
        # keeps J < 0. For non-negative sigma this equals clamp_min(sigma, 1e-5).
        sigma = torch.where(sigma < 0, sigma.clamp_max(-1e-5), sigma.clamp_min(1e-5))

        # Rotation R = U @ Vh (proper rotation since det(U) = det(Vh) = +1)
        R = torch.matmul(U, Vh)  # (...,3,3)

        # Corotated stress part: tau_c = 2*mu*(F - R) @ F^T
        Ft = F.transpose(-2, -1)  # (...,3,3)
        tau_c = 2.0 * mu * torch.matmul(F - R, Ft)  # (...,3,3)

        # Volumetric part: tau_v = lambda * J * (J - 1) * I
        J = torch.prod(sigma, dim=-1)[..., None, None]  # (...,1,1), signed det(F)
        I = torch.eye(3, dtype=F.dtype, device=F.device)  # (3,3), broadcasts
        tau_v = la * J * (J - 1) * I  # (...,3,3)

        # Kirchhoff stress
        kirchhoff_stress = tau_c + tau_v  # (...,3,3)

        return kirchhoff_stress
